# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
#

import argparse
import os
import time
from logging import getLogger

import torch
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim
import torch.utils.data as data
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import json

from src.utils import (
    bool_flag,
    initialize_exp,
    restart_from_checkpoint,
    fix_random_seeds,
    AverageMeter,
    init_distributed_mode,
    accuracy,
)
import src.resnet50 as resnet_models

# Root logger; the functions below log through it.
logger = getLogger()


# Command-line interface of the linear-evaluation script.
parser = argparse.ArgumentParser(description="Evaluate models: Linear classification on ImageNet")

#########################
#### main parameters ####
#########################
parser.add_argument("--dump_path", type=str, default=".",
                    help="experiment dump path for checkpoints and log")
parser.add_argument("--seed", type=int, default=31, help="seed")
parser.add_argument("--data_path", type=str, default="/path/to/imagenet",
                    help="path to dataset repository")
parser.add_argument("--workers", default=8, type=int,
                    help="number of data loading workers")

#########################
#### model parameters ###
#########################
parser.add_argument("--arch", default="resnet50", type=str, help="convnet architecture")
parser.add_argument("--pretrained", default="", type=str, help="path to pretrained weights")
parser.add_argument("--global_pooling", default=True, type=bool_flag,
                    help="if True, we use the resnet50 global average pooling")
parser.add_argument("--use_bn", default=False, type=bool_flag,
                    help="optionally add a batchnorm layer before the linear classifier")

#########################
#### optim parameters ###
#########################
parser.add_argument("--epochs", default=100, type=int,
                    help="number of total epochs to run")
parser.add_argument("--batch_size", default=32, type=int,
                    help="batch size per gpu, i.e. how many unique instances per gpu")
parser.add_argument("--lr", default=0.3, type=float, help="initial learning rate")
parser.add_argument("--wd", default=1e-6, type=float, help="weight decay")
parser.add_argument("--nesterov", default=False, type=bool_flag, help="nesterov momentum")
parser.add_argument("--scheduler_type", default="cosine", type=str, choices=["step", "cosine"],
                    help="learning rate schedule: multi-step decay or cosine annealing")
# for multi-step learning rate decay
parser.add_argument("--decay_epochs", type=int, nargs="+", default=[60, 80],
                    help="Epochs at which to decay learning rate.")
parser.add_argument("--gamma", type=float, default=0.1, help="decay factor")
# for cosine learning rate schedule
parser.add_argument("--final_lr", type=float, default=0, help="final learning rate")

#########################
#### dist parameters ###
#########################
parser.add_argument("--dist_url", default="env://", type=str,
                    help="url used to set up distributed training")
parser.add_argument("--world_size", default=-1, type=int, help="""
                    number of processes: it is set automatically and
                    should not be passed as argument""")
parser.add_argument("--rank", default=0, type=int, help="""rank of this process:
                    it is set automatically and should not be passed as argument""")
parser.add_argument("--local_rank", default=0, type=int,
                    help="this argument is not used and should be ignored")

#########################
#### eval parameters ####
#########################
parser.add_argument("--num_classes", default=1000, type=int,
                    help="number of output classes of the linear classifier")
parser.add_argument("--eval", action='store_true',
                    help="only run validation with the current weights and exit, skipping training")
parser.add_argument("--dataset", default="in1k", type=str,
                    help="dataset name: 'cifar10', 'cifar100' or 'objnet'; any other value "
                         "expects an image-folder layout with train/ and val/ under --data_path")



def get_objnet_mappings(val_loader, mappings_folder='/datasets/objectnet-1.0/mappings/'):
    """Build the class-index translation tables between ImageNet-1k and ObjectNet.

    Args:
        val_loader: loader whose ``dataset.class_to_idx`` maps ObjectNet folder
            names to ObjectNet class indices.
        mappings_folder: directory holding the four mapping files read below
            (``objectnet_to_imagenet_1k.json``, ``folder_to_objectnet_label.json``,
            ``pytorch_to_imagenet_2012_id.json``, ``imagenet_to_label_2012_v2``).

    Returns:
        A pair ``(i_idx_to_o_idxs, o_idx_to_i_idxs)``:
        ``i_idx_to_o_idxs`` maps an ImageNet class index to one ObjectNet class
        index, and ``o_idx_to_i_idxs`` maps an ObjectNet class index to the list
        of ImageNet class indices it covers.

    Raises:
        KeyError: if a label or folder named in one mapping file is missing
            from another one (or from ``class_to_idx``).
    """
    def _load_json(filename):
        with open(os.path.join(mappings_folder, filename)) as file_handle:
            return json.load(file_handle)

    # ObjectNet label -> list of ImageNet labels ("; "-separated in the file).
    o_label_to_i_labels = {
        o_label: all_i_label.split("; ")
        for o_label, all_i_label in _load_json("objectnet_to_imagenet_1k.json").items()
    }

    # ObjectNet label -> ObjectNet class index, going through the folder names.
    o_folder_to_o_idx = val_loader.dataset.class_to_idx
    o_label_to_o_idx = {
        o_label: o_folder_to_o_idx[o_folder]
        for o_folder, o_label in _load_json("folder_to_objectnet_label.json").items()
    }

    # ImageNet label -> ImageNet class index, going through the line numbers of
    # the label file. Only the line terminator is stripped, so a final line
    # without a trailing newline keeps its last character.
    i_idx_to_i_line = _load_json("pytorch_to_imagenet_2012_id.json")
    with open(os.path.join(mappings_folder, "imagenet_to_label_2012_v2")) as file_handle:
        i_line_to_i_label = [line.rstrip("\n") for line in file_handle]
    i_label_to_i_idx = {
        i_line_to_i_label[i_line]: int(i_idx)
        for i_idx, i_line in i_idx_to_i_line.items()
    }

    # Final mapping of interest: ObjectNet index -> ImageNet indices.
    o_idx_to_i_idxs = {
        o_label_to_o_idx[o_label]: [i_label_to_i_idx[i_label] for i_label in i_labels]
        for o_label, i_labels in o_label_to_i_labels.items()
    }

    # Inverse mapping. If an ImageNet index appears under several ObjectNet
    # classes, the one visited last wins.
    i_idx_to_o_idxs = {}
    for o_idx, i_idxs in o_idx_to_i_idxs.items():
        for i_idx in i_idxs:
            i_idx_to_o_idxs[int(i_idx)] = int(o_idx)
    return i_idx_to_o_idxs, o_idx_to_i_idxs

@torch.no_grad()
def accuracy2(output, target, in1k2objnet, objnet2in1k):
    """Compute top-1 accuracy on ObjectNet from ImageNet-1k class scores.

    The arg-max ImageNet prediction of every sample is translated to an
    ObjectNet class index through ``in1k2objnet``; predictions without a
    translation become the sentinel 2000. Only samples whose target is a key
    of ``objnet2in1k`` are scored.

    Args:
        output: class scores; the arg-max is taken over the last dimension.
        target: ObjectNet class indices, one per sample. It is not modified.
        in1k2objnet: dict mapping ImageNet class index -> ObjectNet class index.
        objnet2in1k: dict keyed by the ObjectNet class indices that have an
            ImageNet counterpart (only the keys are used).

    Returns:
        ``[[acc], num_valid]`` where ``acc`` is a 0-dim tensor holding the
        accuracy in percent over the scored samples (0 when there are none)
        and ``num_valid`` is a 0-dim tensor counting the scored samples.
    """
    # reshape(-1) instead of squeeze(): a batch of one must stay 1-D.
    pred = output.argmax(-1).reshape(-1)
    target = target.reshape(-1)
    device = target.device

    # One host transfer per tensor instead of one .item() call per element.
    mapped_pred = torch.tensor(
        [in1k2objnet.get(p, 2000) for p in pred.tolist()],
        dtype=torch.long, device=device,
    )
    valid = torch.tensor(
        [t in objnet2in1k for t in target.tolist()],
        dtype=torch.bool, device=device,
    )

    num_valid = valid.sum()
    if num_valid.item() > 0:
        num_correct = (mapped_pred[valid] == target[valid]).sum().item()
        acc = num_correct * (100.0 / num_valid)
    else:
        acc = torch.zeros((), device=device)
    return [[acc], num_valid]

def _extract_backbone_state_dict(checkpoint):
    """Return backbone-only weights from a MoCo, DINO-style or SwAV/DC checkpoint.

    The checkpoint flavour is detected from its keys, wrapper prefixes are
    stripped from the parameter names and head parameters are dropped.
    """
    state_dict = checkpoint['state_dict'] if 'state_dict' in checkpoint else checkpoint

    if any('module.encoder_q' in k for k in state_dict):
        print('moco model')
        prefix = "module.encoder_q."
        # retain only encoder_q up to before the embedding layer
        return {
            k[len(prefix):]: v
            for k, v in state_dict.items()
            if k.startswith('module.encoder_q') and not k.startswith('module.encoder_q.fc')
        }

    if 'teacher' in checkpoint:
        print('dino type model')
        state_dict = checkpoint['teacher']
        wrapper_prefix = 'module.backbone.'
    else:
        print('swav/DC model')
        wrapper_prefix = 'module.'

    head_prefixes = ('projection_head', 'prototypes', 'fc')
    renamed = {k.replace(wrapper_prefix, ''): v for k, v in state_dict.items()}
    return {k: v for k, v in renamed.items() if not k.startswith(head_prefixes)}


def main():
    """Train and/or evaluate a linear classifier on top of a frozen backbone.

    Parses the command line, builds the data loaders for the requested
    dataset, loads the pretrained backbone weights, restores any classifier
    checkpoint found in ``--dump_path`` and then either runs a single
    validation pass (``--eval``, always the case for 'objnet') or the full
    training loop with per-epoch validation and checkpointing on rank 0.
    """
    global args, best_acc
    args = parser.parse_args()
    init_distributed_mode(args)
    fix_random_seeds(args.seed)
    logger, training_stats = initialize_exp(
        args, "epoch", "loss", "prec1", "prec5", "loss_val", "prec1_val", "prec5_val"
    )
    # NOTE(review): the first std value is 0.228, not the commonly quoted 0.229.
    # Left untouched on purpose - confirm it matches what the backbone was trained with.
    tr_normalize = transforms.Normalize(
        mean=[0.485, 0.456, 0.406], std=[0.228, 0.224, 0.225]
    )
    train_transform = transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        tr_normalize,
    ])
    val_transform = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        tr_normalize,
    ])

    # build data
    if args.dataset == 'objnet':
        # Both datasets read the same folder with the validation transform.
        train_dataset = datasets.ImageFolder(args.data_path, val_transform)
        val_dataset = datasets.ImageFolder(args.data_path, val_transform)
        # Force evaluation-only mode: `args.eval` is the flag checked below, so
        # the classifier is never trained on the ObjectNet images.
        args.eval = True
    elif args.dataset == 'cifar10':
        train_dataset = datasets.CIFAR10(args.data_path, train=True, transform=train_transform)
        val_dataset = datasets.CIFAR10(args.data_path, train=False, transform=val_transform)
    elif args.dataset == 'cifar100':
        train_dataset = datasets.CIFAR100(args.data_path, train=True, transform=train_transform)
        val_dataset = datasets.CIFAR100(args.data_path, train=False, transform=val_transform)
    else:
        train_dataset = datasets.ImageFolder(os.path.join(args.data_path, "train"), transform=train_transform)
        val_dataset = datasets.ImageFolder(os.path.join(args.data_path, "val"), transform=val_transform)

    sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        sampler=sampler,
        batch_size=args.batch_size,
        num_workers=args.workers,
        pin_memory=True,
    )
    # No sampler here: every process iterates the complete validation set.
    val_loader = torch.utils.data.DataLoader(
        val_dataset,
        batch_size=args.batch_size,
        num_workers=args.workers,
        pin_memory=True,
    )
    logger.info("Building data done")

    # build model
    model = resnet_models.__dict__[args.arch](output_dim=0, eval_mode=True)
    linear_classifier = RegLog(args.num_classes, args.arch, args.global_pooling, args.use_bn)

    # convert batch norm layers (if any)
    linear_classifier = nn.SyncBatchNorm.convert_sync_batchnorm(linear_classifier)

    # model to gpu; only the classifier is wrapped for distributed training
    model = model.cuda()
    linear_classifier = linear_classifier.cuda()
    linear_classifier = nn.parallel.DistributedDataParallel(
        linear_classifier,
        device_ids=[args.gpu_to_work_on],
        find_unused_parameters=True,
    )
    model.eval()

    # load weights
    if args.pretrained:
        if os.path.isfile(args.pretrained):
            print("=> loading checkpoint '{}'".format(args.pretrained))
            # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
            checkpoint = torch.load(args.pretrained, map_location="cpu")
            state_dict = _extract_backbone_state_dict(checkpoint)

            args.start_epoch = 0
            # strict=False tolerates the dropped head parameters; the printed
            # result lists missing/unexpected keys so a mismatch is visible.
            msg = model.load_state_dict(state_dict, strict=False)
            print(msg)
            print("=> loaded pre-trained model '{}'".format(args.pretrained))
        else:
            print("=> no checkpoint found at '{}'".format(args.pretrained))

    # set optimizer (classifier parameters only; the backbone stays frozen)
    optimizer = torch.optim.SGD(
        linear_classifier.parameters(),
        lr=args.lr,
        nesterov=args.nesterov,
        momentum=0.9,
        weight_decay=args.wd,
    )

    # set scheduler
    if args.scheduler_type == "step":
        scheduler = torch.optim.lr_scheduler.MultiStepLR(
            optimizer, args.decay_epochs, gamma=args.gamma
        )
    elif args.scheduler_type == "cosine":
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer, args.epochs, eta_min=args.final_lr
        )

    # Optionally resume from a checkpoint
    to_restore = {"epoch": 0, "best_acc": 0.}
    restart_from_checkpoint(
        os.path.join(args.dump_path, "checkpoint.pth.tar"),
        run_variables=to_restore,
        state_dict=linear_classifier,
        optimizer=optimizer,
        scheduler=scheduler,
    )
    start_epoch = to_restore["epoch"]
    best_acc = to_restore["best_acc"]
    cudnn.benchmark = True

    if args.eval:
        scores_val = validate_network(val_loader, model, linear_classifier, args.dataset)
        print(scores_val)
        return

    for epoch in range(start_epoch, args.epochs):

        # train the network for one epoch
        logger.info("============ Starting epoch %i ... ============" % epoch)

        # set samplers
        train_loader.sampler.set_epoch(epoch)

        scores = train(model, linear_classifier, optimizer, train_loader, epoch)
        scores_val = validate_network(val_loader, model, linear_classifier, args.dataset)
        training_stats.update(scores + scores_val)

        scheduler.step()

        # save checkpoint
        if args.rank == 0:
            save_dict = {
                "epoch": epoch + 1,
                "state_dict": linear_classifier.state_dict(),
                "optimizer": optimizer.state_dict(),
                "scheduler": scheduler.state_dict(),
                "best_acc": best_acc,
            }
            torch.save(save_dict, os.path.join(args.dump_path, "checkpoint.pth.tar"))
    logger.info("Training of the supervised linear classifier on frozen features completed.\n"
                "Top-1 test accuracy: {acc:.1f}".format(acc=best_acc))


class RegLog(nn.Module):
    """Creates logistic regression on top of frozen features.

    The input feature map is average pooled, optionally batch-normalized,
    flattened and fed to a single linear layer.

    Args:
        num_labels: number of output classes.
        arch: backbone name; fixes the input width of the linear layer.
        global_avg: if True, pool the feature map down to 1x1; otherwise use a
            6x6 average pooling with stride 1 (only allowed for "resnet50").
        use_bn: add a BatchNorm2d before the linear layer. Only takes effect
            when ``global_avg`` is False.

    Raises:
        ValueError: if ``global_avg`` is True and ``arch`` is not supported.
        AssertionError: if ``global_avg`` is False and ``arch`` is not "resnet50".
    """

    # Feature width fed to the linear layer after global average pooling.
    _GLOBAL_POOL_DIMS = {"resnet50": 2048, "resnet50w2": 4096, "resnet50w4": 8192}

    def __init__(self, num_labels, arch="resnet50", global_avg=False, use_bn=True):
        super(RegLog, self).__init__()
        self.bn = None
        if global_avg:
            if arch not in self._GLOBAL_POOL_DIMS:
                # Fail with a clear message instead of an unbound local below.
                raise ValueError(
                    "unsupported arch '{}' for global pooling, expected one of {}".format(
                        arch, sorted(self._GLOBAL_POOL_DIMS)
                    )
                )
            s = self._GLOBAL_POOL_DIMS[arch]
            self.av_pool = nn.AdaptiveAvgPool2d((1, 1))
        else:
            assert arch == "resnet50", "non-global pooling is only defined for resnet50"
            s = 8192
            self.av_pool = nn.AvgPool2d(6, stride=1)
            if use_bn:
                self.bn = nn.BatchNorm2d(2048)
        self.linear = nn.Linear(s, num_labels)
        self.linear.weight.data.normal_(mean=0.0, std=0.01)
        self.linear.bias.data.zero_()

    def forward(self, x):
        """Return the class scores for the feature map ``x``."""
        # average pool the final feature map
        x = self.av_pool(x)

        # optional BN
        if self.bn is not None:
            x = self.bn(x)

        # flatten everything but the batch dimension
        x = x.view(x.size(0), -1)

        # linear layer
        return self.linear(x)


def train(model, reglog, optimizer, loader, epoch):
    """
    Run one training epoch of the classifier on top of the frozen backbone.

    The backbone `model` is kept in eval mode and evaluated without gradients;
    only `reglog` is put in train mode and stepped by `optimizer`.
    Returns the tuple (epoch, average loss, average top-1, average top-5).
    """
    # timing meters
    batch_time = AverageMeter()
    data_time = AverageMeter()

    # optimisation meters
    top1 = AverageMeter()
    top5 = AverageMeter()
    losses = AverageMeter()
    tick = time.perf_counter()

    model.eval()
    reglog.train()
    criterion = nn.CrossEntropyLoss().cuda()

    for step, (images, labels) in enumerate(loader):
        # time spent waiting for the batch
        data_time.update(time.perf_counter() - tick)

        # move the batch to the gpu
        images = images.cuda(non_blocking=True)
        labels = labels.cuda(non_blocking=True)
        n_samples = images.size(0)

        # frozen features, then trainable classifier
        with torch.no_grad():
            features = model(images)
        logits = reglog(features)

        # cross entropy loss
        loss = criterion(logits, labels)

        # backward pass and parameter update
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # running statistics
        acc1, acc5 = accuracy(logits, labels, topk=(1, 5))
        losses.update(loss.item(), n_samples)
        top1.update(acc1[0], n_samples)
        top5.update(acc5[0], n_samples)

        batch_time.update(time.perf_counter() - tick)
        tick = time.perf_counter()

        # only the first process reports, once every 50 iterations
        if args.rank != 0 or step % 50 != 0:
            continue
        message = (
            "Epoch[{0}] - Iter: [{1}/{2}]\t"
            "Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t"
            "Data {data_time.val:.3f} ({data_time.avg:.3f})\t"
            "Loss {loss.val:.4f} ({loss.avg:.4f})\t"
            "Prec {top1.val:.3f} ({top1.avg:.3f})\t"
            "LR {lr}"
        ).format(
            epoch,
            step,
            len(loader),
            batch_time=batch_time,
            data_time=data_time,
            loss=losses,
            top1=top1,
            lr=optimizer.param_groups[0]["lr"],
        )
        logger.info(message)

    mean_top1 = top1.avg.item()
    mean_top5 = top5.avg.item()
    return epoch, losses.avg, mean_top1, mean_top5


def validate_network(val_loader, model, linear_classifier, dataset):
    """Evaluate the frozen backbone plus linear classifier on ``val_loader``.

    For ``dataset == 'objnet'`` the ImageNet predictions are translated to
    ObjectNet classes and only samples with a translatable target are scored;
    top-5 is reported as 0 in that case. Any other dataset uses the regular
    top-1/top-5 accuracy. The module-level ``best_acc`` is raised when the
    measured top-1 exceeds it.

    Args:
        val_loader: loader yielding ``(input, target)`` batches.
        model: backbone producing the features.
        linear_classifier: classifier applied to the backbone output.
        dataset: dataset name; only 'objnet' changes the behaviour.

    Returns:
        Tuple ``(average loss, average top-1, average top-5)``.
    """
    global best_acc
    batch_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()

    is_objnet = dataset == 'objnet'
    if is_objnet:
        print('running eval on objnet!!', flush=True)
        in1k2objnet, objnet2in1k = get_objnet_mappings(val_loader)

    # switch to evaluate mode
    model.eval()
    linear_classifier.eval()

    criterion = nn.CrossEntropyLoss().cuda()

    with torch.no_grad():
        end = time.perf_counter()
        for inp, target in val_loader:

            # move to gpu
            inp = inp.cuda(non_blocking=True)
            target = target.cuda(non_blocking=True)

            # compute output
            output = linear_classifier(model(inp))
            # NOTE(review): for objnet the targets are ObjectNet indices while
            # the scores are over ImageNet classes, so this loss is presumably
            # not meaningful there - confirm before relying on it.
            loss = criterion(output, target)
            losses.update(loss.item(), inp.size(0))

            if is_objnet:
                acc1, len_valid = accuracy2(output, target, in1k2objnet, objnet2in1k)
                # acc1 is a percentage over the valid samples only, so the
                # running average must be weighted by that count rather than
                # by the full batch size.
                n_scored = int(len_valid)
                acc5 = torch.tensor([0])
            else:
                acc1, acc5 = accuracy(output, target, topk=(1, 5))
                n_scored = inp.size(0)

            # a batch without any scored sample contributes nothing
            if n_scored > 0:
                top1.update(acc1[0], n_scored)
                top5.update(acc5[0], n_scored)

            # measure elapsed time
            batch_time.update(time.perf_counter() - end)
            end = time.perf_counter()

    if top1.avg.item() > best_acc:
        best_acc = top1.avg.item()

    if args.rank == 0:
        logger.info(
            "Test:\t"
            "Time {batch_time.avg:.3f}\t"
            "Loss {loss.avg:.4f}\t"
            "Acc@1 {top1.avg:.3f}\t"
            "Best Acc@1 so far {acc:.1f}".format(
                batch_time=batch_time, loss=losses, top1=top1, acc=best_acc))

    return losses.avg, top1.avg.item(), top5.avg.item()


if __name__ == "__main__":
    main()
